"""
Phase 6: Safe Supply Stress Test
==================================
Headline deliverable: econometric estimates → policy-relevant projections.

6a. Baseline projections for each current safe issuer, 2024-2054 in 5-year steps
6b. Monte Carlo simulation (coefficient uncertainty → safe supply distributions)
6c. Downgrade cascade scenarios
6d. Output tables and figures

Output: table9_projections.md, table10_safe_supply.md,
        figure3_fan_chart.png, figure4_downgrade_probs.png,
        phase6_baseline_projections.csv, phase6_safe_probs.csv, phase6_scenarios.csv
"""

import sys
import numpy as np
import pandas as pd
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Hard-coded absolute location of this sub-project; every other path below is
# derived from it, so this is the one line to change when relocating the project.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/safe_asset_cliff")
# Parent directory; used below to put the sibling "multilateral/src" package on sys.path.
ROOT_DIR = PROJECT_DIR.parent
# Input panel (cliff_panel.csv) is read from here.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
# Markdown tables and CSV results are written here.
TABLES_DIR = PROJECT_DIR / "output" / "tables"
# PNG figures are written here.
FIGURES_DIR = PROJECT_DIR / "output" / "figures"

sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Lowest numeric rating that still counts as a "safe" issuer (18 = AA-).
# Predicted ratings are compared against this with `>=`.
SAFE_THRESHOLD = 18  # AA-

# Current safe issuers (2024, rating >= AA-).
# Maps ISO3 country code -> (numeric rating, letter rating). In this table the
# numeric scale reads 21 = AAA, 20 = AA+, 19 = AA. The dict's insertion order is
# also the iteration order used throughout the simulations.
CURRENT_SAFE = {
    'USA': (20, 'AA+'), 'DEU': (21, 'AAA'), 'GBR': (19, 'AA'),
    'FRA': (19, 'AA'), 'CAN': (21, 'AAA'), 'AUS': (21, 'AAA'),
    'CHE': (21, 'AAA'), 'NLD': (21, 'AAA'), 'AUT': (20, 'AA+'),
    'DNK': (21, 'AAA'), 'FIN': (20, 'AA+'), 'NOR': (21, 'AAA'),
    'SWE': (21, 'AAA'), 'SGP': (21, 'AAA'), 'HKG': (20, 'AA+'),
    'LUX': (21, 'AAA'), 'NZL': (20, 'AA+'), 'BEL': (19, 'AA'),
    'KOR': (19, 'AA'), 'TWN': (20, 'AA+'), 'CZE': (19, 'AA'),
    'KWT': (19, 'AA'), 'QAT': (19, 'AA'), 'ARE': (19, 'AA'),
}


def estimate_rating_model(df):
    """Fit the OADR spline rating model and build an approximate coefficient VCV.

    The regressors are ``old_dep`` and ``oadr_spline_20`` plus whichever of
    ``rgdp_growth``, ``inflation`` and ``kaopen`` exist as columns in *df*.
    Rows with a missing dependent variable or regressor are dropped.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing at least ``rating_numeric``, ``old_dep``,
        ``oadr_spline_20``, ``iso3`` and ``year``.

    Returns
    -------
    tuple
        ``(model, x_vars, vcv, est)`` where ``model`` is the fitted
        ``PanelGLS`` instance, ``x_vars`` the regressor names in column
        order, ``vcv`` a ``(k x k)`` matrix whose first row/column belongs to
        the intercept, and ``est`` the estimation sample.

    Notes
    -----
    ``vcv`` is NOT the GLS covariance. It is the classical
    ``sigma^2 (X'X)^-1`` formula evaluated with residuals from the fitted
    coefficients (``model.constant`` and ``model.beta``), i.e. an
    approximation used only to drive the Monte Carlo draws.
    """
    controls = [c for c in ('rgdp_growth', 'inflation', 'kaopen') if c in df.columns]
    x_vars = ['old_dep', 'oadr_spline_20'] + controls

    est = df.dropna(subset=['rating_numeric'] + x_vars).copy()

    model = PanelGLS()
    y = est['rating_numeric'].values
    X = est[x_vars].values
    model.fit(y, X, est['iso3'].values, est['year'].values)
    model.summary(feature_names=x_vars)

    # Prepend the intercept column explicitly. statsmodels' add_constant
    # defaults to has_constant='skip', which silently adds nothing when a
    # regressor happens to be constant in the sample; that would leave the
    # design matrix one column short of [constant, beta].
    X_const = np.column_stack([np.ones(len(X)), X])
    n, k = X_const.shape

    beta_full = np.concatenate([[model.constant], model.beta])
    resid = y - X_const @ beta_full
    sigma2 = np.sum(resid**2) / (n - k)
    XtX_inv = np.linalg.inv(X_const.T @ X_const)
    vcv = sigma2 * XtX_inv  # (k x k) including constant

    return model, x_vars, vcv, est


def project_demographics(df, proj_years):
    """Project the old-age dependency ratio (``old_dep``) for each safe issuer.

    For every country in ``CURRENT_SAFE`` that has at least five non-missing
    ``old_dep`` observations, each year in *proj_years* is filled by the
    first rule that applies:

    1. the observed ``old_dep`` for that year;
    2. the ``oadr_plus10`` value recorded ten years earlier (only when the
       panel carries an ``oadr_plus10`` column);
    3. linear extrapolation from the last observed value, using the OLS
       slope of ``old_dep`` on ``year`` over observations from 2005 onward
       (or the last ten observations when fewer than five fall in that
       window).

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with ``iso3``, ``year`` and ``old_dep`` columns;
        ``oadr_plus10`` is optional.
    proj_years : iterable of int
        Years to project.

    Returns
    -------
    dict
        ``{iso3: {year: projected old_dep}}``. Countries without enough data
        are omitted.
    """
    # The forward-looking column is optional; without it we rely on the trend.
    has_forward = 'oadr_plus10' in df.columns
    projections = {}

    for iso3 in CURRENT_SAFE:
        cdf = df[df['iso3'] == iso3].sort_values('year')
        if cdf.empty:
            continue

        # Historical OADR; need a minimum history to fit a trend.
        oadr_data = cdf[['year', 'old_dep']].dropna()
        if len(oadr_data) < 5:
            continue

        # year -> value lookups (first row wins if a year is duplicated).
        hist = oadr_data.drop_duplicates('year')
        observed = dict(zip(hist['year'], hist['old_dep']))
        if has_forward:
            fwd = cdf[['year', 'oadr_plus10']].dropna().drop_duplicates('year')
            forward = dict(zip(fwd['year'], fwd['oadr_plus10']))
        else:
            forward = {}

        # Trend window: 2005 onward, falling back to the last ten points.
        recent = oadr_data[oadr_data['year'] >= 2005]
        if len(recent) < 5:
            recent = oadr_data.tail(10)

        # `recent` always holds >= 5 points here, so a linear fit is defined.
        slope = np.polyfit(recent['year'].values, recent['old_dep'].values, 1)[0]

        last_year = int(oadr_data['year'].max())
        last_oadr = float(observed[last_year])

        proj = {}
        for yr in proj_years:
            if yr in observed:
                proj[yr] = float(observed[yr])
            elif (yr - 10) in forward:
                # oadr_plus10 recorded at year t describes year t + 10.
                proj[yr] = float(forward[yr - 10])
            else:
                proj[yr] = last_oadr + slope * (yr - last_year)

        projections[iso3] = proj

    return projections


def get_country_controls(df, iso3):
    """Return recent-average macro/fiscal controls for one country.

    Each variable is averaged over the country's last five observations that
    have a non-missing ``rgdp_growth``. A variable falls back to its default
    (``rgdp_growth`` 2.0, ``inflation`` 2.0, ``kaopen`` 1.0,
    ``govt_debt_gdp`` 60, ``ngdp_usd`` 1000) when the country has no usable
    rows, when the column is absent from *df*, or when all of its recent
    values are missing. Using one set of defaults everywhere avoids silently
    substituting 0 for, e.g., debt or GDP.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with ``iso3`` and ``year`` columns plus any of the control
        columns listed above.
    iso3 : str
        Country code to look up.

    Returns
    -------
    dict
        ``{variable name: value}`` for the five variables above.
    """
    defaults = {'rgdp_growth': 2.0, 'inflation': 2.0, 'kaopen': 1.0,
                'govt_debt_gdp': 60, 'ngdp_usd': 1000}

    cdf = df[df['iso3'] == iso3].sort_values('year')
    if 'rgdp_growth' not in cdf.columns:
        return dict(defaults)

    last = cdf.dropna(subset=['rgdp_growth']).tail(5)
    if last.empty:
        return dict(defaults)

    result = {}
    for var, fallback in defaults.items():
        if var in last.columns:
            vals = last[var].dropna()
            result[var] = float(vals.mean()) if len(vals) > 0 else fallback
        else:
            result[var] = fallback

    return result


def _design_row(x_vars, oadr, ctrl, rgdp_shock=0.0):
    """Build one regressor row ``[1, *x_vars]`` for a country-year.

    The row follows the order of *x_vars* (as returned by
    ``estimate_rating_model``) so it always lines up with the coefficient
    vector, even when a control was dropped from the model.
    """
    values = [1.0]
    for name in x_vars:
        if name == 'old_dep':
            values.append(oadr)
        elif name == 'oadr_spline_20':
            # Spline term: excess of OADR over the 0.20 knot.
            values.append(max(0, oadr - 0.20))
        elif name == 'rgdp_growth':
            values.append(ctrl['rgdp_growth'] + rgdp_shock)
        elif name == 'kaopen':
            values.append(ctrl.get('kaopen', 1.0))
        else:
            values.append(ctrl[name])
    return np.array(values)


def main():
    """Run the Phase 6 safe-supply stress test end to end.

    Steps, mirroring the numbered console output:

    1. fit the rating model and obtain an approximate coefficient VCV;
    2. project ``old_dep`` for each current safe issuer on a 5-year grid
       from 2024 to 2054;
    3. collect recent-average controls per country;
    4. compute baseline (point-estimate) rating projections;
    5. Monte Carlo: draw coefficient vectors, add rating noise, and count a
       country's debt as safe whenever its predicted rating is at least
       ``SAFE_THRESHOLD``;
    6-7. write Tables 9 and 10 as markdown;
    8. re-run the final projection year under stress scenarios;
    9. draw the fan chart and downgrade-probability figure (skipped when
       matplotlib is not installed);
    10. write CSV results.

    Raises
    ------
    ValueError
        If the VCV shape does not match the coefficient vector.
    RuntimeError
        If no safe issuer could be projected.
    """
    print("=" * 70)
    print("PHASE 6: Safe Supply Stress Test")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "cliff_panel.csv")
    print(f"Loaded: {df['iso3'].nunique()} countries, {len(df):,} obs")

    # Create output locations up front so a long run cannot fail at write time.
    TABLES_DIR.mkdir(parents=True, exist_ok=True)
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)

    # ── [1] Estimate rating model ──
    print("\n[1] Estimating OADR spline rating model ...")
    model, x_vars, vcv, _ = estimate_rating_model(df)

    # Full coefficient vector (including constant)
    beta_full = np.concatenate([[model.constant], model.beta])
    n_params = len(beta_full)
    if np.shape(vcv) != (n_params, n_params):
        raise ValueError(
            f"VCV shape {np.shape(vcv)} does not match {n_params} coefficients")

    # ── [2] Project demographics ──
    print("\n[2] Projecting demographics for safe issuers ...")
    proj_years = list(range(2024, 2055, 5))
    demo_proj = project_demographics(df, proj_years)
    print(f"  Projected {len(demo_proj)} countries through {proj_years[-1]}")
    if not demo_proj:
        raise RuntimeError("No safe issuer has enough OADR data to project")

    # Safe issuers that could be projected, in CURRENT_SAFE order. The order
    # matters: it fixes the sequence of random draws below.
    issuers = [iso3 for iso3 in CURRENT_SAFE if iso3 in demo_proj]

    # ── [3] Get baseline controls ──
    print("\n[3] Getting country control means ...")
    country_controls = {iso3: get_country_controls(df, iso3) for iso3 in CURRENT_SAFE}

    # Regressor rows depend only on country and year, not on the simulation
    # draw, so build them once and reuse them for baseline and Monte Carlo.
    design = {
        iso3: {yr: _design_row(x_vars, demo_proj[iso3][yr], country_controls[iso3])
               for yr in proj_years if yr in demo_proj[iso3]}
        for iso3 in issuers
    }

    # ── [4] Baseline projections ──
    print("\n[4] Baseline projections ...")
    baseline_rows = []
    for iso3 in issuers:
        current_rating, rating_label = CURRENT_SAFE[iso3]
        row = {'iso3': iso3, 'current_rating': rating_label,
               'current_rating_num': current_rating}

        for yr, x_vec in design[iso3].items():
            # Predicted rating, bounded to the 0-21 numeric scale
            pred = np.clip(x_vec @ beta_full, 0, 21)
            row[f'rating_{yr}'] = round(pred, 1)
            row[f'oadr_{yr}'] = round(demo_proj[iso3][yr] * 100, 1)

        baseline_rows.append(row)

    baseline_df = pd.DataFrame(baseline_rows)
    rating_cols = [f'rating_{yr}' for yr in proj_years
                   if f'rating_{yr}' in baseline_df.columns]
    print(baseline_df[['iso3', 'current_rating'] + rating_cols].to_string(index=False))

    # ── [5] Monte Carlo simulation ──
    print("\n[5] Monte Carlo simulation (1000 draws) ...")
    N_SIM = 1000
    np.random.seed(42)

    # Draw coefficient vectors from MVN. check_valid='raise' is required for
    # the fallback to ever trigger on a non-PSD matrix: the default only
    # emits a warning, and warnings are silenced at module level.
    try:
        beta_draws = np.random.multivariate_normal(
            beta_full, vcv, size=N_SIM, check_valid='raise')
    except (ValueError, np.linalg.LinAlgError):
        # Fallback: independent draws using the diagonal only
        print("  WARNING: VCV not positive definite, using diagonal")
        se_full = np.sqrt(np.abs(np.diag(vcv)))
        beta_draws = np.column_stack([
            np.random.normal(beta_full[j], se_full[j], N_SIM)
            for j in range(n_params)
        ])

    # Rating noise scale.
    # NOTE(review): model.se[0] looks like the standard error of the first
    # slope coefficient rather than a residual SD — confirm this is the
    # intended noise scale.
    noise_sd = model.se[0] * 0.5

    # For each simulation, project ratings and classify safe/not-safe
    safe_supply_sims = {yr: np.zeros(N_SIM) for yr in proj_years}
    # Denominator of the supply ratio: combined GDP of the projected issuers.
    total_gdp = sum(country_controls[iso3].get('ngdp_usd', 0) for iso3 in issuers)
    country_safe_prob = {iso3: {yr: 0 for yr in proj_years} for iso3 in issuers}
    # Debt stock per country (debt/GDP in percent times GDP); constant over
    # years and simulations.
    debt_usd = {
        iso3: (country_controls[iso3].get('govt_debt_gdp', 60) / 100
               * country_controls[iso3].get('ngdp_usd', 1000))
        for iso3 in issuers
    }

    for sim in range(N_SIM):
        beta_sim = beta_draws[sim]
        for iso3 in issuers:
            for yr, x_vec in design[iso3].items():
                # Point prediction plus noise for rating uncertainty
                pred = x_vec @ beta_sim + np.random.normal(0, noise_sd)
                pred = np.clip(pred, 0, 21)

                if pred >= SAFE_THRESHOLD:
                    safe_supply_sims[yr][sim] += debt_usd[iso3]
                    country_safe_prob[iso3][yr] += 1

    # Convert counts to probabilities
    for iso3 in country_safe_prob:
        for yr in proj_years:
            country_safe_prob[iso3][yr] /= N_SIM

    # Safe supply ratio
    safe_supply_ratio = {yr: safe_supply_sims[yr] / total_gdp if total_gdp > 0
                         else np.zeros(N_SIM) for yr in proj_years}
    # Expected number of safe issuers per year = sum of per-country probabilities
    expected_n_safe = {yr: sum(country_safe_prob[iso3][yr] for iso3 in country_safe_prob)
                       for yr in proj_years}

    print("\n  Safe supply ratio distribution:")
    print(f"  {'Year':>6}  {'Median':>8}  {'P10':>8}  {'P90':>8}  {'Mean N_safe':>10}")
    for yr in proj_years:
        med = np.median(safe_supply_ratio[yr])
        p10 = np.percentile(safe_supply_ratio[yr], 10)
        p90 = np.percentile(safe_supply_ratio[yr], 90)
        print(f"  {yr:>6}  {med:>8.4f}  {p10:>8.4f}  {p90:>8.4f}  {expected_n_safe[yr]:>10.1f}")

    # ── [6] Write Table 9: Country projections ──
    print("\n[6] Writing Table 9: Country projections ...")
    headers = ['Country', 'Current'] + [str(yr) for yr in proj_years] + ['P(safe 2050)']
    # The grid has no 2050 point; report the projection year closest to it.
    p_safe_year = min(proj_years, key=lambda yr: abs(yr - 2050))
    t9_rows = []
    for iso3 in sorted(issuers):
        row_data = baseline_df[baseline_df['iso3'] == iso3]
        if len(row_data) == 0:
            continue
        rd = row_data.iloc[0]
        row = [iso3, rd['current_rating']]
        for yr in proj_years:
            col = f'rating_{yr}'
            if col in rd and pd.notna(rd[col]):
                row.append(f"{rd[col]:.1f}")
            else:
                row.append("-")
        row.append(f"{country_safe_prob[iso3][p_safe_year]:.2f}")
        t9_rows.append(row)

    write_markdown_table(
        TABLES_DIR / "table9_projections.md",
        "Table 9: Projected Sovereign Ratings by Country",
        headers, t9_rows,
        notes="Projected using OADR spline model with UN WPP medium-variant demographics. "
              "Rating on 21-point S&P scale. P(safe) = Monte Carlo probability of remaining AA- or above."
    )

    # ── [7] Write Table 10: Aggregate safe supply ──
    print("\n[7] Writing Table 10: Safe supply distributions ...")
    t10_rows = []
    for yr in proj_years:
        ssr = safe_supply_ratio[yr]
        t10_rows.append([
            str(yr),
            f"{np.median(ssr):.4f}",
            f"{np.percentile(ssr, 10):.4f}",
            f"{np.percentile(ssr, 90):.4f}",
            f"{np.mean(ssr):.4f}",
            f"{np.std(ssr):.4f}",
            f"{expected_n_safe[yr]:.1f}",
        ])

    write_markdown_table(
        TABLES_DIR / "table10_safe_supply.md",
        "Table 10: Aggregate Safe Supply Distributions (2024-2054)",
        ["Year", "Median", "P10", "P90", "Mean", "SD", "E[N safe]"],
        t10_rows,
        notes="Safe supply = GDP-weighted debt of countries with rating >= AA-. "
              "1,000 Monte Carlo simulations drawing from coefficient VCV."
    )

    # ── [8] Stress scenarios ──
    print("\n[8] Stress scenarios ...")
    scenarios = {
        'Baseline': {'exp_rev_gap_shock': 0, 'rgdp_shock': 0, 'oadr_mult': 1.0},
        'Fiscal stress (+1 SD gap)': {'exp_rev_gap_shock': 3.0, 'rgdp_shock': 0, 'oadr_mult': 1.0},
        'Rate normalization': {'exp_rev_gap_shock': 0, 'rgdp_shock': -0.5, 'oadr_mult': 1.0},
        'Aging acceleration': {'exp_rev_gap_shock': 0, 'rgdp_shock': 0, 'oadr_mult': 1.15},
    }

    scenario_results = []
    target_yr = proj_years[-1]

    for scenario_name, shocks in scenarios.items():
        if shocks['exp_rev_gap_shock'] != 0:
            # The rating model has no fiscal-gap regressor, so this shock
            # cannot enter the prediction. Say so rather than let the
            # scenario pass as a real stress result.
            print(f"  NOTE: '{scenario_name}' sets exp_rev_gap_shock, which is not a "
                  "model regressor; results differ from Baseline only by simulation noise")

        # Shocked regressor rows for the target year (independent of the draw)
        scenario_rows = [
            _design_row(x_vars,
                        demo_proj[iso3][target_yr] * shocks['oadr_mult'],
                        country_controls[iso3],
                        rgdp_shock=shocks['rgdp_shock'])
            for iso3 in issuers if target_yr in demo_proj[iso3]
        ]

        n_safe_list = []
        for sim in range(min(N_SIM, 500)):
            beta_sim = beta_draws[sim]
            n_safe = 0
            for x_vec in scenario_rows:
                pred = x_vec @ beta_sim + np.random.normal(0, noise_sd)
                pred = np.clip(pred, 0, 21)
                if pred >= SAFE_THRESHOLD:
                    n_safe += 1
            n_safe_list.append(n_safe)

        scenario_results.append({
            'scenario': scenario_name,
            'median_n_safe': np.median(n_safe_list),
            'mean_n_safe': np.mean(n_safe_list),
            'p10_n_safe': np.percentile(n_safe_list, 10),
            'p90_n_safe': np.percentile(n_safe_list, 90),
        })
        print(f"  {scenario_name:35s}  N_safe: median={np.median(n_safe_list):.0f}, "
              f"[{np.percentile(n_safe_list, 10):.0f}, {np.percentile(n_safe_list, 90):.0f}]")

    # ── [9] Figures ──
    print("\n[9] Generating figures ...")

    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        # Figure 3: Fan chart
        fig, ax = plt.subplots(figsize=(10, 6))
        years_plot = proj_years
        medians = [np.median(safe_supply_ratio[yr]) for yr in years_plot]
        band = {q: [np.percentile(safe_supply_ratio[yr], q) for yr in years_plot]
                for q in (10, 25, 75, 90)}

        ax.fill_between(years_plot, band[10], band[90], alpha=0.15, color='steelblue', label='10th-90th')
        ax.fill_between(years_plot, band[25], band[75], alpha=0.3, color='steelblue', label='25th-75th')
        ax.plot(years_plot, medians, 'b-o', linewidth=2, markersize=6, label='Median')
        # Reference line at the first-year median
        ax.axhline(y=medians[0], color='gray', linestyle=':', alpha=0.5)
        ax.set_xlabel('Year', fontsize=12)
        # NOTE(review): the denominator is the combined GDP of the projected
        # safe issuers (total_gdp), not world GDP — confirm the label.
        ax.set_ylabel('Safe Supply / Global GDP', fontsize=12)
        ax.set_title('Figure 3: Safe Supply Ratio Projections (Fan Chart)', fontsize=13)
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(FIGURES_DIR / "figure3_fan_chart.png", dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"  Saved: {FIGURES_DIR / 'figure3_fan_chart.png'}")

        # Figure 4: Individual country downgrade probability
        fig, ax = plt.subplots(figsize=(12, 7))
        focus_countries = ['USA', 'DEU', 'GBR', 'FRA', 'JPN', 'CAN', 'AUS',
                           'NLD', 'BEL', 'KOR', 'FIN', 'AUT']
        # Only countries that were actually simulated can be plotted
        focus_countries = [c for c in focus_countries if c in country_safe_prob]

        for iso3 in focus_countries:
            probs = [country_safe_prob[iso3].get(yr, np.nan) for yr in proj_years]
            # Plot downgrade probability (1 - P(safe))
            dg_probs = [1 - p for p in probs]
            ax.plot(proj_years, dg_probs, '-o', markersize=4, label=iso3)

        ax.set_xlabel('Year', fontsize=12)
        ax.set_ylabel('P(loss of safe status)', fontsize=12)
        ax.set_title('Figure 4: Downgrade Probability Trajectories', fontsize=13)
        ax.legend(fontsize=8, ncol=3)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(-0.05, 1.05)
        fig.tight_layout()
        fig.savefig(FIGURES_DIR / "figure4_downgrade_probs.png", dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"  Saved: {FIGURES_DIR / 'figure4_downgrade_probs.png'}")

    except ImportError:
        print("  matplotlib not available, skipping figures")

    # ── [10] Save results ──
    print("\n[10] Saving results ...")
    baseline_df.to_csv(TABLES_DIR / "phase6_baseline_projections.csv", index=False)

    prob_rows = [
        {'iso3': iso3, 'year': yr, 'p_safe': country_safe_prob[iso3][yr]}
        for iso3 in country_safe_prob
        for yr in proj_years
    ]
    pd.DataFrame(prob_rows).to_csv(TABLES_DIR / "phase6_safe_probs.csv", index=False)
    pd.DataFrame(scenario_results).to_csv(TABLES_DIR / "phase6_scenarios.csv", index=False)

    print(f"  Saved phase 6 results to {TABLES_DIR}")

    print("\n" + "=" * 70)
    print("Phase 6 complete.")
    print("=" * 70)


def write_markdown_table(path, title, headers, rows, notes=None):
    """Write a markdown table to *path* and print a confirmation line.

    The file contains a level-3 heading with *title*, a header row, an
    alignment row (first column left-aligned, all others right-aligned), one
    line per entry in *rows*, and, when *notes* is given, an italicised
    notes line after a blank line. The file is written as UTF-8.

    Parameters
    ----------
    path : str or pathlib.Path
        Destination file. Missing parent directories are created.
    title : str
        Heading text.
    headers : sequence of str
        Column headings.
    rows : iterable of sequences
        Table rows; each cell is converted with ``str``.
    notes : str, optional
        Footnote text placed under the table.
    """
    path = Path(path)

    # First column left-aligned, the rest right-aligned.
    alignment = [":--" if i == 0 else "--:" for i in range(len(headers))]

    lines = [
        f"### {title}",
        "",
        "| " + " | ".join(headers) + " |",
        "|" + "|".join(alignment) + "|",
    ]
    lines.extend("| " + " | ".join(str(c) for c in row) + " |" for row in rows)
    if notes:
        lines.extend(["", f"*{notes}*"])
    lines.append("")  # ensures the file ends with a newline

    # Make sure the target directory exists instead of failing on write.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(lines), encoding="utf-8")
    print(f"  Saved: {path}")


# Run the full pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
